%matplotlib inline
from matplotlib import pyplot as plt
import torch
import numpy as np
from torchvision import models
# Loading vgg19 model and extracting the feature part
vgg = models.vgg19(pretrained = True).features
#Freezing the parameters in vgg
for param in vgg.parameters():
param.requires_grad_(False)
#move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg.to(device)
Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (17): ReLU(inplace=True) (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (24): ReLU(inplace=True) (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (26): ReLU(inplace=True) (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (31): ReLU(inplace=True) (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (33): ReLU(inplace=True) (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (35): ReLU(inplace=True) (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) )
#download the MS COCO dataset for training
!wget http://images.cocodataset.org/zips/test2017.zip
!mkdir './dataset'
!unzip -q ./test2017.zip -d './dataset'
--2020-11-11 02:08:24-- http://images.cocodataset.org/zips/test2017.zip Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.78.60 Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.78.60|:80... connected. HTTP request sent, awaiting response... 200 OK Length: 6646970404 (6.2G) [application/zip] Saving to: ‘test2017.zip’ test2017.zip 100%[===================>] 6.19G 37.1MB/s in 2m 4s 2020-11-11 02:10:28 (51.3 MB/s) - ‘test2017.zip’ saved [6646970404/6646970404]
from torchvision import datasets
import torchvision.transforms as transforms
#Transform train images
batch_size = 4
num_workers = 0
train_transform = transforms.Compose([
transforms.Resize((264, 264)),
transforms.RandomCrop(256),
transforms.ToTensor(),
transforms.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
])
train_data = datasets.ImageFolder('./dataset', transform = train_transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size)
from PIL import Image
import torchvision.transforms as transforms
#Load image
def load_image(img_path):
image = Image.open(img_path).convert('RGB')
#print(image.size)
image_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
])
image = image_transform(image)[:3, :, :].unsqueeze(0)
#print(image.shape)
return image
#load style image
style_image = load_image('great_wave.jpg')
style_image = style_image.to(device)
# Un-normalize image tensors
def denormalize(tensor):
image = tensor.to('cpu').clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
return image
#for name, layer in vgg._modules.items():
#print(name)
#print(layer)
import torch.nn as nn
class VGG(nn.Module):
def __init__(self):
super(VGG, self).__init__()
def forward(self, x):
layers = {'3': 'relu1_2',
'8': 'relu2_2',
'17': 'relu3_4',
'22': 'relu4_2',
'26': 'relu4_4', ## content representation
'35': 'relu5_4'}
features = {}
for name, layer in vgg._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x
return features
import torch.nn.functional as F
class transformer(nn.Module):
def __init__(self):
super(transformer, self).__init__()
self.conv_block = nn.Sequential(
conv(3, 32, 9, 1),
nn.ReLU(),
conv(32, 64, 3, 2),
nn.ReLU(),
conv(64, 128, 3, 2),
nn.ReLU()
)
self.residual_block = nn.Sequential(
ResidualBlock(128),
ResidualBlock(128),
ResidualBlock(128),
ResidualBlock(128),
ResidualBlock(128)
)
#add relu activation
self.deconv_block = nn.Sequential(
deconv(128, 64, 3, 2, 1),
nn.ReLU(),
deconv(64, 32, 3, 2, 1),
nn.ReLU(),
conv(32, 3, 9, 1, normalize = False)
)
def forward(self, x):
x = self.conv_block(x)
x = self.residual_block(x)
x = self.deconv_block(x)
return x
class conv(nn.Module):
def __init__(self,in_channels, out_channels, kernel_size, stride, normalize = True):
super(conv, self).__init__()
self.reflection_pad = nn.ReflectionPad2d(kernel_size//2)
self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
self.norm = nn.InstanceNorm2d(out_channels, affine = True) if normalize else None
def forward(self, x):
x = self.reflection_pad(x)
x = self.conv_layer(x)
if self.norm is not None:
x = self.norm(x)
return x
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = conv(channels, channels, 3, 1)
self.conv2 = conv(channels, channels, 3, 1)
def forward(self, x):
in_x = x
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x) +in_x
return x
class deconv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, normalize = True):
super(deconv, self).__init__()
#self.reflection_pad = nn.ReflectionPad2d(kernel_size//2)
padding_size = kernel_size//2
self.deconv_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,padding_size, output_padding)
self.norm = nn.InstanceNorm2d(out_channels, affine = True) if normalize else None
def forward(self, x):
#x = self.reflection_pad(x)
x = self.deconv_layer(x)
if self.norm is not None:
x = self.norm(x)
return x
#gram matrix
def gram_matrix(x):
b, d, h, w = x.size()
tensor = x.view(b, d, h*w)
tensor_transpose = tensor.transpose(1, 2)
gram = torch.bmm(tensor,tensor_transpose)/(d*h*w)
return gram
vgg_net = VGG()
transformer_net = transformer().to(device)
#compute style gram matrix
style_features = vgg_net(style_image)
style_gram = { layer:gram_matrix(style_features[layer]) for layer in style_features}
import os
import torch.optim as optim
from PIL import Image
learning_rate = 0.001
optimizer = optim.Adam(transformer_net.parameters(), lr = learning_rate)
criterion = nn.MSELoss().to(device)
epochs = 1
content_weight = 1
style_weight = 12
checkpoint_path = 'checkpoints'
images_path = 'train_results'
os.makedirs(checkpoint_path, exist_ok = True)
os.makedirs(images_path, exist_ok = True)
for epoch in range(epochs):
for batch, (images,_) in enumerate(train_loader):
batch_size = images.shape[0]
#print(batch)
#print(images.shape)
images = images.to(device)
optimizer.zero_grad()
output_images = transformer_net(images)
#print(output_images.shape)
# get features
features = vgg_net(images)
output_features = vgg_net(output_images)
content_loss = content_weight*criterion( output_features['relu2_2'], features['relu2_2'])
style_loss = 0
for layer_name, layer in output_features.items():
gram_f = gram_matrix(layer)
style_loss += criterion(gram_f, style_gram[layer_name][:batch_size])
style_loss *= style_weight
total_loss = content_loss + style_loss
total_loss.backward()
optimizer.step()
if batch%400 == 399 or batch == len(train_loader)-1:
print('Batch {}/{}'.format(batch+1, len(train_loader)))
print('Total loss {}'.format(total_loss.item()))
#get an image
fig, ax = plt.subplots(1, 2)
input_image = images[0].clone().detach()
input_image = denormalize(input_image)
ax[0].imshow(input_image)
ax[0].set_title('Content image')
ax[0].set_xticks([])
ax[0].set_yticks([])
output_image = output_images[0].clone().detach()
output_image = denormalize(output_image)
ax[1].imshow(output_image)
ax[1].set_title('Stylized image')
ax[1].set_xticks([])
ax[1].set_yticks([])
plt.savefig(os.path.join(images_path,'result_{}.png'.format(batch+1)))
plt.show()
torch.save(transformer_net.state_dict(), os.path.join(checkpoint_path,'model_{}'.format(batch+1)))
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 64, 64])) that is different to the input size (torch.Size([4, 64, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 128, 128])) that is different to the input size (torch.Size([4, 128, 128])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 256, 256])) that is different to the input size (torch.Size([4, 256, 256])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 512, 512])) that is different to the input size (torch.Size([4, 512, 512])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 400/10168 Total loss 44.687782287597656
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 800/10168 Total loss 22.06247901916504
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 1200/10168 Total loss 16.27564239501953
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 1600/10168 Total loss 15.917465209960938
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 2000/10168 Total loss 14.430047988891602
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 2400/10168 Total loss 14.091354370117188
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 2800/10168 Total loss 13.166362762451172
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 3200/10168 Total loss 11.666818618774414
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 3600/10168 Total loss 11.936600685119629
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 4000/10168 Total loss 11.357491493225098
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 4400/10168 Total loss 11.756816864013672
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 4800/10168 Total loss 14.453180313110352
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 5200/10168 Total loss 11.594239234924316
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 5600/10168 Total loss 12.628231048583984
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 6000/10168 Total loss 13.654515266418457
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 6400/10168 Total loss 9.585670471191406
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 6800/10168 Total loss 9.375436782836914
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 7200/10168 Total loss 9.322410583496094
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 7600/10168 Total loss 8.270928382873535
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 8000/10168 Total loss 9.798877716064453
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 8400/10168 Total loss 8.695219039916992
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 8800/10168 Total loss 9.502508163452148
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 9200/10168 Total loss 10.261441230773926
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 9600/10168 Total loss 8.499666213989258
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 10000/10168 Total loss 10.684791564941406
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 64, 64])) that is different to the input size (torch.Size([2, 64, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 128, 128])) that is different to the input size (torch.Size([2, 128, 128])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 256, 256])) that is different to the input size (torch.Size([2, 256, 256])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 512, 512])) that is different to the input size (torch.Size([2, 512, 512])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 10168/10168 Total loss 7.845559120178223
#save final model
torch.save(transformer_net.state_dict(),'final_model')
#Test the model
import glob
test_results_path = 'test_results'
os.makedirs(test_results_path, exist_ok = True)
transformer_model = transformer().to(device)
transformer_model.load_state_dict(torch.load('final_model'))
transformer_model.eval()
sample_images_path = list(glob.glob('test_images/*'))
#print(sample_images_path)
for index, sample_image_path in enumerate(sample_images_path):
#print(sample_image_path)
test_image = load_image(sample_image_path)
test_image = test_image.to(device)
test_output = transformer_model(test_image)
test_output = test_output.clone().detach()
stylized_image = denormalize(test_output)
content_image = denormalize(test_image.clone().detach())
fig, ax = plt.subplots(1, 2, figsize = (10, 20))
ax[0].imshow(content_image)
ax[0].set_title('Content_image')
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[1].imshow(stylized_image)
ax[1].set_title('Stylized_image')
ax[1].set_xticks([])
ax[1].set_yticks([])
plt.show()
plt.savefig(os.path.join(test_results_path,'output_{}'.format(index)))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>